import gensim
from gensim.models.coherencemodel import CoherenceModel
import gensim.parsing.preprocessing as gpp
from gensim.corpora import Dictionary
from nltk.stem.wordnet import WordNetLemmatizer
from gensim.models.ldamodel import LdaModel
from pprint import pprint
from nltk.tokenize import RegexpTokenizer
import pandas as pd
import os
# NOTE: the gensim version is not printed here; inspect gensim.__version__ if needed.

# Load the raw comments CSV, keep risk-related comments, and aggregate them by post
def append_comments_by_post(in_fn, out_fn):
    """Filter the comment-level CSV and write one aggregated row per post.

    Reads ``in_fn``, drops comments without a body, keeps the comments whose
    body matches at least one risk term and whose score is positive, then
    groups the survivors by ``post_id`` and writes the result to ``out_fn``.

    Each output row holds the space-joined comment bodies, the summed score,
    the smallest ``created_timestamp``, the first ``subreddit`` and the number
    of comments in the group (``num_comments``). ``post_id`` is written as the
    CSV index column.
    """
    column_types = {
        'post_id': str,
        'body': str,
        'score': int,
        'Year': int,
        'Month': int,
        'Day': int,
    }
    comments = pd.read_csv(in_fn, dtype=column_types)

    # Comments without any text cannot be matched against the risk terms.
    comments = comments.loc[comments['body'].notna()]

    risk_terms = [
        "social distance", "socially distanced", "conspiracy",
        "open", "conspiracy", "reopen", "restrictions",
        "frontline", "vulnerable", "save lives", "lives", "fight",
        "doomer", "forcing", "regulate", "enforce",
        "measures", "dangerous", "harmful", "freedom", "infringe", "fear",
        "trust", "mask", "vigilance", "panic", "horrifying", "symptoms", "scientists",
        "contagious", "non-contagious", "saving", "quarantine", "panic", "damage", "control", "bad news",
        "strict", "science", "pandemic", "herd immunity", "herd", "immunity", "recession", "unconstitutional",
    ]

    # The terms are OR-ed into a single (case-sensitive) regular expression.
    risk_pattern = '|'.join(risk_terms)
    mentions_risk = comments['body'].str.contains(risk_pattern)
    comments = comments.loc[mentions_risk]

    # Discard comments that were not upvoted.
    upvoted = comments['score'] > 0
    comments = comments.loc[upvoted]

    # One row per post; counting the grouping column gives the group size.
    aggregations = {
        'body': ' '.join,
        'score': 'sum',
        'created_timestamp': 'min',
        'subreddit': 'first',
        'post_id': 'count',
    }
    per_post = comments.groupby('post_id').agg(aggregations)
    per_post = per_post.rename(columns={'post_id': 'num_comments'})

    per_post.to_csv(out_fn)



# Run a simple LDA topic model on the comments


def run_lda(dataset, 
            n_topics, # How many topics to return
            min_count = 20 # How many docs a word must appear in to be included
           ):
    """Train a gensim LDA topic model on the ``body`` column of ``dataset``.

    The text is lower-cased, cleaned with gensim's preprocessing filters
    (punctuation, repeated whitespace, digits, stopwords and short tokens are
    removed) and lemmatized with WordNet. The dictionary and the bag-of-words
    corpus are both built from the lemmatized tokens, so the vocabulary the
    model is trained on matches the vocabulary that was frequency-filtered.

    Args:
        dataset: DataFrame with a string ``body`` column, one document per row.
        n_topics: Number of topics to fit.
        min_count: Minimum number of documents a token must appear in to be
            kept in the dictionary. Tokens appearing in more than half of the
            documents are dropped as well.

    Returns:
        A ``(model, corpus, dictionary, dataset)`` tuple: the trained
        ``LdaModel``, the bag-of-words corpus, the gensim ``Dictionary`` and
        the unmodified input ``dataset``.
    """
    print("Preprocessing documents...")
    raw_texts = dataset['body'].str.lower().tolist()
    token_filters = [gpp.strip_punctuation,
                     gpp.strip_multiple_whitespaces,
                     gpp.strip_numeric,
                     gpp.remove_stopwords,
                     gpp.strip_short]
    docs = [gpp.preprocess_string(text, filters=token_filters) for text in raw_texts]

    # Lemmatize BEFORE building the dictionary. Previously the lemmas were only
    # added to the dictionary as extra documents (inflating the document
    # frequencies used by filter_extremes) while the corpus was still encoded
    # from the unlemmatized tokens, so lemmatization never reached the model.
    lemmatizer = WordNetLemmatizer()
    docs = [[lemmatizer.lemmatize(token) for token in doc] for doc in docs]

    dictionary = Dictionary(docs)
    dictionary.filter_extremes(no_below=min_count, no_above=0.5)

    corpus = [dictionary.doc2bow(doc) for doc in docs]

    print('Number of unique tokens: %d' % len(dictionary))
    print('Number of documents: %d' % len(corpus))

    # Train LDA model
    print("Running the model...")
    # Set training parameters.
    num_topics = n_topics
    chunksize = 2000
    passes = 20
    iterations = 400
    eval_every = None  # Don't evaluate model perplexity, takes too much time.

    model = LdaModel(
        corpus=corpus,
        id2word=dictionary,
        chunksize=chunksize,
        alpha='auto',
        eta='auto',
        iterations=iterations,
        num_topics=num_topics,
        passes=passes,
        eval_every=eval_every
    )

    return (model, corpus, dictionary, dataset)


def prep_data(sample_size=None):
    """Return the per-post comment table, building the cached CSV if needed.

    If ``../data/comments_by_post.csv`` does not exist yet it is created from
    ``../data/semi_anonymized_comments.csv`` via ``append_comments_by_post``.
    The cached file is then loaded with ``post_id`` kept as a string (as in
    the raw-data load), so numeric-looking ids are not coerced to numbers.

    Args:
        sample_size: Optional maximum number of rows to return. When given and
            smaller than the table, that many rows are drawn at random with a
            fixed seed so repeated calls return the same sample. ``None``
            (the default) returns every row.

    Returns:
        DataFrame with a fresh 0..n-1 index; the previous row labels are kept
        in an ``index`` column.
    """
    in_fn = '../data/semi_anonymized_comments.csv'
    out_fn = '../data/comments_by_post.csv'
    # Only run the (slow) preprocessing when the cached output is missing.
    if not os.path.exists(out_fn):
        append_comments_by_post(in_fn, out_fn)
    df = pd.read_csv(out_fn, dtype={'post_id': str})
    if sample_size is not None and sample_size < len(df):
        df = df.sample(n=sample_size, random_state=0)
    # Reset after sampling so downstream positional/index alignment stays valid.
    df = df.reset_index()
    return df



def extract_topics(lda_result, num_topics):
    """Write the top 50 words of every topic to a CSV file.

    Args:
        lda_result: The ``(model, corpus, dictionary, docs)`` tuple returned
            by ``run_lda``; only the model is used.
        num_topics: Used only to name the output file
            ``../data/lda_topics_{num_topics}_topics.csv``.

    The CSV has one row per (topic, word) pair with the columns ``topic``,
    ``word`` and ``probability``.
    """
    model, corpus, dictionary, docs = lda_result

    # show_topics() returns only 10 topics by default; request all of them so
    # models with more than 10 topics are not silently truncated.
    topics = model.show_topics(num_topics=model.num_topics,
                               formatted=False,
                               num_words=50)
    result = [{'topic': topic_num, 'word': word, 'probability': probability}
              for topic_num, words in topics
              for word, probability in words]

    pd.DataFrame(result).to_csv(f'../data/lda_topics_{num_topics}_topics.csv', index=False)


def extract_docs(lda_result, num_topics,  num_docs=None):
    """Write the per-document topic weights, with document metadata, to CSV.

    Args:
        lda_result: The ``(model, corpus, dictionary, docs)`` tuple returned
            by ``run_lda``. ``docs`` must have ``body``, ``post_id`` and
            ``subreddit`` columns and one row per corpus entry, in the same
            order as the corpus.
        num_topics: Used only to name the output file
            ``../data/doc_topic_df_{num_topics}_topics.csv``.
        num_docs: Currently unused; accepted for backward compatibility.
            Every document is written.

    The CSV has one ``topic_<k>`` column for each of the model's topics
    (in topic order), followed by ``original_text``, ``post_id`` and
    ``subreddit``. A topic that the model does not report for a document is
    left empty in that row.
    """
    model, corpus, dictionary, docs = lda_result
    doc_topics = model.get_document_topics(corpus)

    # One {column name: weight} mapping per document.
    result = [{f"topic_{topic_number}": weight for topic_number, weight in doc}
              for doc in doc_topics]

    # Fix the column set up front so the schema does not depend on which
    # topics happen to be reported for the first documents.
    topic_columns = [f"topic_{topic_number}" for topic_number in range(model.num_topics)]
    doc_topic_df = pd.DataFrame(result, columns=topic_columns)

    # Attach metadata by position: corpus rows follow the order of ``docs``,
    # whose index is not guaranteed to be 0..n-1 (index alignment would
    # otherwise silently misplace or blank out values).
    doc_topic_df['original_text'] = docs['body'].to_numpy()
    doc_topic_df['post_id'] = docs['post_id'].to_numpy()
    doc_topic_df['subreddit'] = docs['subreddit'].to_numpy()

    doc_topic_df.to_csv(f'../data/doc_topic_df_{num_topics}_topics.csv', index=False)



def get_number_of_topics(n_list = range(3,25), sample_size=5000):
    """Fit one LDA model per candidate topic count and report c_v coherence.

    Args:
        n_list: Iterable of topic counts to try.
        sample_size: Maximum number of posts to fit on. When the dataset is
            larger, that many rows are sampled once (fixed seed) and reused
            for every topic count. ``None`` uses the full dataset.

    Returns:
        List of ``(coherence, k)`` tuples in the order of ``n_list``. The
        list is also printed.
    """
    dataset = prep_data()
    if sample_size is not None and len(dataset) > sample_size:
        # Sample here rather than via prep_data, and reset the index so the
        # rows are numbered 0..n-1 again.
        dataset = dataset.sample(n=sample_size, random_state=0).reset_index(drop=True)

    coherence = []
    for k in n_list:
        print(f'Running with {k} topics')
        model, corpus, dictionary, _ = run_lda(dataset, n_topics=k)
        # CoherenceModel needs tokenized documents, not the raw DataFrame that
        # run_lda returns. Rebuild them from the bag-of-words corpus so they
        # use exactly the model's vocabulary. NOTE: word order within a
        # document is not recoverable from the bag-of-words form, so c_v's
        # sliding window sees tokens grouped by id rather than in text order.
        texts = [[dictionary[token_id]
                  for token_id, count in bow
                  for _ in range(count)]
                 for bow in corpus]
        coherence_model = CoherenceModel(model=model, 
                                texts=texts, 
                                dictionary=dictionary, 
                                coherence='c_v')
        coherence.append((coherence_model.get_coherence(), k))
    print(coherence)
    return coherence


def main(get_n_topics, n_topics):
    """Run either the topic-count search or a single LDA fit with exports.

    When ``get_n_topics`` is truthy only the coherence search is run.
    Otherwise a model with ``n_topics`` topics is fitted on the prepared
    data and its topic words and per-document topic weights are exported.
    """
    if get_n_topics:
        get_number_of_topics()
        return

    prepared = prep_data()
    fitted = run_lda(prepared, n_topics=n_topics)
    extract_topics(fitted, n_topics)
    extract_docs(fitted, n_topics, num_docs=30)
    

# Script entry point: fit a single 10-topic model and export its results
# (the topic-count search is skipped because get_n_topics is False).
if __name__ == "__main__":
    main(get_n_topics=False, n_topics=10)